// Copyright (c) 2014-2017 Statoshi Developers
// Copyright (c) 2017-2023 Vincent Thiery
// Copyright (c) 2020-2024 The Dash Core developers
// Distributed under the MIT software license, see the accompanying
// file COPYING or http://www.opensource.org/licenses/mit-license.php.

#include <stats/client.h>

#include <logging.h>
#include <random.h>
#include <stats/rawsender.h>
#include <sync.h>
#include <util/check.h>
#include <util/strencodings.h>
#include <util/system.h>
#include <util/translation.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <random>

namespace {
/** Tolerance used when testing whether a sample rate is effectively 0 or 1 */
constexpr float EPSILON{0.0001f};
/** Token that separates the scheme from the remainder of a URL */
constexpr std::string_view URL_SCHEME_DELIMITER{"://"};

/** Port used for the Statsd server when none is supplied */
constexpr uint16_t DEFAULT_STATSD_PORT{8125};

/** Separator placed between two complete Statsd messages */
constexpr char STATSD_MSG_DELIMITER{'\n'};
/** Separator placed between namespace components of a Statsd key */
constexpr char STATSD_NS_DELIMITER{'.'};
/** Type tag identifying a Statsd message as a count */
constexpr char STATSD_METRIC_COUNT[]{"c"};
/** Type tag identifying a Statsd message as a gauge */
constexpr char STATSD_METRIC_GAUGE[]{"g"};
/** Type tag identifying a Statsd message as a timing */
constexpr char STATSD_METRIC_TIMING[]{"ms"};

/**
 * StatsdClient implementation that turns metric calls into Statsd-formatted
 * messages and passes them to a RawSender.
 *
 * Every metric entry point funnels into _send(), which applies sample-rate
 * based filtering before the message is built.
 */
class StatsdClientImpl final : public StatsdClient
{
public:
    explicit StatsdClientImpl(const std::string& host, uint16_t port, uint64_t batch_size, uint64_t interval_ms,
                              const std::string& prefix, const std::string& suffix, std::optional<bilingual_str>& error);
    ~StatsdClientImpl() = default;

    /* Counters: dec/inc are counts with a fixed delta of -1/+1 */
    bool dec(std::string_view key, float sample_rate) override EXCLUSIVE_LOCKS_REQUIRED(!cs)
    {
        return _send(key, int64_t{-1}, STATSD_METRIC_COUNT, sample_rate);
    }
    bool inc(std::string_view key, float sample_rate) override EXCLUSIVE_LOCKS_REQUIRED(!cs)
    {
        return _send(key, int64_t{1}, STATSD_METRIC_COUNT, sample_rate);
    }
    bool count(std::string_view key, int64_t delta, float sample_rate) override EXCLUSIVE_LOCKS_REQUIRED(!cs)
    {
        return _send(key, delta, STATSD_METRIC_COUNT, sample_rate);
    }

    /* Gauges: integral and floating point flavours share the same type tag */
    bool gauge(std::string_view key, int64_t value, float sample_rate) override EXCLUSIVE_LOCKS_REQUIRED(!cs)
    {
        return _send(key, value, STATSD_METRIC_GAUGE, sample_rate);
    }
    bool gaugeDouble(std::string_view key, double value, float sample_rate) override EXCLUSIVE_LOCKS_REQUIRED(!cs)
    {
        return _send(key, value, STATSD_METRIC_GAUGE, sample_rate);
    }

    /* Timings, expressed in milliseconds */
    bool timing(std::string_view key, uint64_t ms, float sample_rate) override EXCLUSIVE_LOCKS_REQUIRED(!cs)
    {
        return _send(key, ms, STATSD_METRIC_TIMING, sample_rate);
    }

    /* Generic senders: the caller supplies the type tag, one overload per supported value type */
    bool send(std::string_view key, double value, std::string_view type, float sample_rate) override
        EXCLUSIVE_LOCKS_REQUIRED(!cs)
    {
        return _send(key, value, type, sample_rate);
    }
    bool send(std::string_view key, int32_t value, std::string_view type, float sample_rate) override
        EXCLUSIVE_LOCKS_REQUIRED(!cs)
    {
        return _send(key, value, type, sample_rate);
    }
    bool send(std::string_view key, int64_t value, std::string_view type, float sample_rate) override
        EXCLUSIVE_LOCKS_REQUIRED(!cs)
    {
        return _send(key, value, type, sample_rate);
    }
    bool send(std::string_view key, uint32_t value, std::string_view type, float sample_rate) override
        EXCLUSIVE_LOCKS_REQUIRED(!cs)
    {
        return _send(key, value, type, sample_rate);
    }
    bool send(std::string_view key, uint64_t value, std::string_view type, float sample_rate) override
        EXCLUSIVE_LOCKS_REQUIRED(!cs)
    {
        return _send(key, value, type, sample_rate);
    }

    /* The client is active for as long as it holds a sender */
    bool active() const override { return static_cast<bool>(m_sender); }

private:
    /* Common implementation behind every metric entry point above */
    template <typename T1>
    inline bool _send(std::string_view key, T1 value, std::string_view type, float sample_rate)
        EXCLUSIVE_LOCKS_REQUIRED(!cs);

    /* Guards the PRNG below */
    mutable Mutex cs;
    /* PRNG consulted to decide whether a sampled (0 < rate < 1) message is sent */
    mutable FastRandomContext insecure_rand GUARDED_BY(cs);

    /* Transport that receives the messages crafted by this client */
    std::unique_ptr<RawSender> m_sender{};

    /* Text placed in front of every key */
    const std::string m_prefix{};
    /* Text placed after every key */
    const std::string m_suffix{};
};
} // anonymous namespace

/** Global Statsd client handle; default-constructed (null) here and assigned elsewhere. */
std::unique_ptr<StatsdClient> g_stats_client;

/**
 * Build a Statsd client from the supplied arguments.
 *
 * Reads -statshost, -statsport, -statsbatchsize, -statsduration, -statsprefix
 * and -statssuffix. -statshost may be a bare host or a "udp://" URL; a non-zero
 * port embedded in the URL overrides -statsport.
 *
 * @param[in] args  Argument source the options are read from.
 * @return A plain StatsdClient when -statshost is empty (stats disabled), a
 *         StatsdClientImpl on success, or an error if an option is rejected or
 *         the implementation reports a failure during construction.
 */
util::Result<std::unique_ptr<StatsdClient>> StatsdClient::make(const ArgsManager& args)
{
    auto host = args.GetArg("-statshost", DEFAULT_STATSD_HOST);
    if (host.empty()) {
        LogPrintf("Transmitting stats are disabled, will not init Statsd client\n");
        return std::make_unique<StatsdClient>();
    }

    const int64_t batch_size = args.GetIntArg("-statsbatchsize", DEFAULT_STATSD_BATCH_SIZE);
    if (batch_size < 0) {
        return util::Error{_("-statsbatchsize cannot be configured with a negative value.")};
    }

    const int64_t interval_ms = args.GetIntArg("-statsduration", DEFAULT_STATSD_DURATION);
    if (interval_ms < 0) {
        return util::Error{_("-statsduration cannot be configured with a negative value.")};
    }

    auto port_arg = args.GetIntArg("-statsport", DEFAULT_STATSD_PORT);
    if (args.IsArgSet("-statsport")) {
        // Port range validation if -statsport is specified.
        if (port_arg < 1 || port_arg > std::numeric_limits<uint16_t>::max()) {
            return util::Error{strprintf(_("Port must be between %d and %d, supplied %d"), 1,
                                         std::numeric_limits<uint16_t>::max(), port_arg)};
        }
    }
    uint16_t port = static_cast<uint16_t>(port_arg);

    // Could be a URL, try to parse it.
    const size_t scheme_idx{host.find(URL_SCHEME_DELIMITER)};
    if (scheme_idx != std::string::npos) {
        // Parse the scheme and trim it out of the URL if we succeed
        if (scheme_idx == 0) {
            return util::Error{_("No text before the scheme delimiter, malformed URL")};
        }
        std::string scheme{ToLower(host.substr(/*pos=*/0, scheme_idx))};
        if (scheme != "udp") {
            return util::Error{_("Unsupported URL scheme, must begin with udp://")};
        }
        host = host.substr(scheme_idx + URL_SCHEME_DELIMITER.length());

        // Drop any path component and parse the port
        const size_t colon_idx{host.rfind(':')};
        if (colon_idx != std::string::npos) {
            // Cut the string at the first forward slash found after the port
            // delimiter (colon), discarding the slash and everything behind it
            const size_t slash_idx{host.find('/', /*pos=*/colon_idx + 1)};
            if (slash_idx != std::string::npos) {
                host.erase(slash_idx);
            }
            uint16_t port_url{0};
            SplitHostPort(host, port_url, host);
            if (port_url != 0) {
                if (args.IsArgSet("-statsport")) {
                    LogPrintf("%s: Supplied URL with port, ignoring -statsport\n", __func__);
                }
                port = port_url;
            }
        } else {
            // There was no port specified, remove everything after the first forward slash
            host = host.substr(/*pos=*/0, host.find('/'));
        }

        if (host.empty()) {
            return util::Error{_("No host specified, malformed URL")};
        }
    }

    auto sanitize_string = [](std::string affix) {
        // Remove one key delimiter from the front and one from the back as
        // they're added back in the constructor. Emptiness is re-checked before
        // looking at the back because erasing the front may have consumed the
        // only character (e.g. an affix of "."), and back() on an empty string
        // is undefined behaviour.
        if (!affix.empty() && affix.front() == STATSD_NS_DELIMITER) affix.erase(affix.begin());
        if (!affix.empty() && affix.back() == STATSD_NS_DELIMITER) affix.pop_back();
        return affix;
    };

    std::optional<bilingual_str> error_opt;
    auto statsd_ptr = std::make_unique<StatsdClientImpl>(
        host, port, batch_size, interval_ms,
        sanitize_string(args.GetArg("-statsprefix", DEFAULT_STATSD_PREFIX)),
        sanitize_string(args.GetArg("-statssuffix", DEFAULT_STATSD_SUFFIX)), error_opt);
    if (error_opt.has_value()) {
        // Construction failed: discard the half-initialized client and surface the error
        statsd_ptr.reset();
        return util::Error{error_opt.value()};
    }
    return {std::move(statsd_ptr)};
}

/**
 * Set up the key prefix/suffix and create the RawSender used for transmission.
 *
 * A non-empty prefix gets a trailing namespace delimiter and a non-empty suffix
 * gets a leading one; empty values are stored unchanged. If the sender reports
 * a problem through `error`, the sender is released again and nothing is logged.
 */
StatsdClientImpl::StatsdClientImpl(const std::string& host, uint16_t port, uint64_t batch_size, uint64_t interval_ms,
                                   const std::string& prefix, const std::string& suffix,
                                   std::optional<bilingual_str>& error) :
    m_prefix{prefix.empty() ? prefix : prefix + STATSD_NS_DELIMITER},
    m_suffix{suffix.empty() ? suffix : STATSD_NS_DELIMITER + suffix}
{
    // Messages are batched up to batch_size and separated by the message delimiter
    const auto batching{std::make_pair(batch_size, static_cast<uint8_t>(STATSD_MSG_DELIMITER))};
    m_sender = std::make_unique<RawSender>(host, port, batching, interval_ms, error);

    if (!error.has_value()) {
        LogPrintf("StatsdClient initialized to transmit stats to %s:%d\n", host, port);
    } else {
        // The sender could not be set up, don't keep it around
        m_sender.reset();
    }
}

template <typename T1>
inline bool StatsdClientImpl::_send(std::string_view key, T1 value, std::string_view type, float sample_rate)
{
    static_assert(std::is_arithmetic<T1>::value, "Must specialize to an arithmetic type");

    // Determine if we should send the message at all but claim that we did even if we don't
    sample_rate = std::clamp(sample_rate, 0.f, 1.f);
    bool always_send = std::fabs(sample_rate - 1.f) < EPSILON;
    bool never_send = std::fabs(sample_rate) < EPSILON;
    if (never_send || (!always_send &&
                       WITH_LOCK(cs, return sample_rate < std::uniform_real_distribution<float>(0.f, 1.f)(insecure_rand)))) {
        return true;
    }

    // Construct the message and if our message isn't always-send, report the sample rate
    RawMessage msg{strprintf("%s%s%s:%f|%s", m_prefix, key, m_suffix, value, type)};
    if (!always_send) {
        msg += strprintf("|@%.2f", sample_rate);
    }

    // Send it and report an error if we encounter one
    if (auto error_opt = Assert(m_sender)->Send(msg); error_opt.has_value()) {
        LogPrintf("ERROR: %s.\n", error_opt->original);
        return false;
    }

    return true;
}
